import math
from pathlib import Path

import pandas as pd
import statsmodels.formula.api as smf


# Union of repeated-7/8/9 codes at widths 1-4, presumably the survey's
# non-response codes (refusal / don't know / no answer -- confirm against the
# codebook). This is a blanket union, not a per-variable list: for many
# variables some of these values are legitimate (e.g. 7, 8, 9 on a 1-10 scale),
# so prefer a variable-specific set where that can happen.
MISSING_CODES = {
    7, 8, 9,
    77, 88, 99,
    777, 888, 999,
    7777, 8888, 9999,
}


def clean_numeric(
    series: pd.Series,
    missing_codes: "set[int] | frozenset[int] | None" = None,
) -> pd.Series:
    """Coerce *series* to numbers and replace missing-value codes with NaN.

    Entries that cannot be parsed as numbers become NaN (``errors="coerce"``).

    Args:
        series: Raw column values.
        missing_codes: Values to treat as missing. ``None`` (the default) keeps
            the historical behaviour of using the module-level ``MISSING_CODES``
            union. That union contains 7, 8, 9, 77, 88, 99, ..., so it also erases
            values that are legitimate for many variables (e.g. 7-9 on a 1-10
            scale, or ages 77/88/99); pass the variable's own codes whenever that
            can happen.

    Returns:
        A numeric Series aligned with *series*, with the codes replaced by NaN.
    """
    # Resolve the default lazily so an explicit code set never touches the union.
    codes = MISSING_CODES if missing_codes is None else missing_codes
    values = pd.to_numeric(series, errors="coerce")
    return values.mask(values.isin(codes))


def zscore(series: pd.Series) -> pd.Series:
    """Standardise *series* to mean 0 and population (ddof=0) SD 1.

    Values are coerced to numeric first; unparseable entries become NaN and
    NaNs are ignored when computing the mean and SD. When the SD is NaN or not
    positive (e.g. a constant or all-missing column) the result is
    ``values * 0``: zero where the input is numeric, NaN where it is missing.
    """
    values = pd.to_numeric(series, errors="coerce")
    centre, spread = values.mean(), values.std(ddof=0)

    degenerate = pd.isna(spread) or not spread > 0
    if degenerate:
        return values * 0
    return (values - centre) / spread


# ESS missing-value codes per variable (refusal / don't know / no answer, plus
# 6666 "not applicable" for netustm), as given in the ESS R8-R11 codebooks --
# re-check when adding rounds or variables. The blanket MISSING_CODES union must
# NOT be used for these columns: 7, 8 and 9 are valid hinctnta deciles and
# eduyrs values, 77/88/99 are valid ages, and small minute counts are valid
# netustm values.
_ESS_MISSING_CODES: dict[str, frozenset[int]] = {
    "netusoft": frozenset({7, 8, 9}),
    # 6666 is presumably assigned to respondents who never use the internet
    # (confirm). It is treated as missing rather than as 6666 minutes.
    # NOTE(review): recoding it to 0 minutes is a defensible alternative that
    # would keep never-users in the log_netustm sample.
    "netustm": frozenset({6666, 7777, 8888, 9999}),
    "agea": frozenset({999}),
    "gndr": frozenset({9}),
    "eduyrs": frozenset({77, 88, 99}),
    "hinctnta": frozenset({77, 88, 99}),
}

# Region-year instruments (standardised) and individual-level controls shared
# by every first-stage specification.
_INSTRUMENTS = ("z_broadband_pc_hh", "z_internet_access_pc_hh")
_CONTROLS = ("agea", "gndr", "eduyrs", "hinctnta")


def _load_first_stage_data(root: Path) -> tuple[pd.DataFrame, list[int]]:
    """Build the respondent-level estimation frame.

    Returns ``(df, fetched_years)``. *df* contains ESS respondents in regunit
    1/2 regions for years present in the Eurostat file, left-joined to the
    region-year broadband data. ESS missing codes are blanked, ``log_netustm``
    and the standardised instruments are added, and only rows with at least one
    instrument present are kept. *fetched_years* is the sorted list of Eurostat
    years.

    Raises:
        pandas.errors.MergeError: if the Eurostat file has duplicate
            (region, year) rows.
    """
    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    reg = pd.read_csv(root / "data_external" / "broadband_region_year_eurostat_isoc_r_broad_h.csv")

    # Normalise the Eurostat keys *before* deriving the year list, so the list,
    # the ESS filter and the merge all use the same integer years.
    reg["region"] = reg["region"].astype(str).str.upper().str.strip()
    reg["year"] = pd.to_numeric(reg["year"], errors="coerce").astype("Int64")
    reg = reg.dropna(subset=["year"])
    fetched_years = sorted(int(y) for y in reg["year"].unique())

    # Focus on NUTS-like regions (regunit 1/2 in ESS), and years we actually fetched.
    ess = ess[ess["regunit"].isin([1, 2])].copy()
    ess["region"] = ess["region"].astype(str).str.upper().str.strip()
    ess = ess[(ess["region"] != "99999") & ess["survey_year"].isin(fetched_years)]

    # many_to_one: a duplicated (region, year) row on the Eurostat side would
    # otherwise silently duplicate respondents, inflating n and distorting SEs.
    df = ess.merge(
        reg,
        left_on=["region", "survey_year"],
        right_on=["region", "year"],
        how="left",
        validate="many_to_one",
    )

    # Endogenous internet-use vars and controls: blank out each variable's own
    # non-response codes.
    for col, codes in _ESS_MISSING_CODES.items():
        values = pd.to_numeric(df[col], errors="coerce")
        df[col] = values.mask(values.isin(codes))
    df["log_netustm"] = df["netustm"].clip(lower=0).map(math.log1p, na_action="ignore")

    # Instruments (region-year), standardised over the merged respondent rows.
    df["z_broadband_pc_hh"] = zscore(df["broadband_pc_hh"])
    df["z_internet_access_pc_hh"] = zscore(df["internet_access_pc_hh"])

    # Keep rows with at least one instrument present.
    df = df[df[list(_INSTRUMENTS)].notna().any(axis=1)].copy()
    return df, fetched_years


def _first_stage_lines(df: pd.DataFrame, dep: str, spec: str, fml: str) -> list[str]:
    """Fit one first-stage OLS (SEs clustered by region) and return summary lines.

    Rows missing any of the keys, *dep*, the instruments or the controls are
    dropped before fitting. If nothing is left, a single ``[SKIP]`` line is
    returned.
    """
    cols = ["cntry", "survey_year", "region", dep, *_INSTRUMENTS, *_CONTROLS]
    d = df[cols].dropna()
    if d.empty:
        return [f"[SKIP] {spec}: no rows after dropna"]

    res = smf.ols(fml, data=d).fit(cov_type="cluster", cov_kwds={"groups": d["region"]})
    try:
        joint = res.f_test(", ".join(f"{name} = 0" for name in _INSTRUMENTS))
        joint_str = f"joint_F={float(joint.fvalue):.2f} p={float(joint.pvalue):.4g}"
    except Exception as exc:  # best-effort diagnostic: report why instead of hiding it
        joint_str = f"joint_F=(na: {type(exc).__name__}: {exc})"

    years = sorted(d["survey_year"].unique().tolist())
    lines = [
        f"[{spec}] n={len(d):,} regions={d['region'].nunique()} "
        f"countries={d['cntry'].nunique()} years={years}"
    ]
    for name in _INSTRUMENTS:
        coef = res.params.get(name, float("nan"))
        se = res.bse.get(name, float("nan"))
        lines.append(f"coef({name})={coef:.4f} se={se:.4f}")
    lines.append(joint_str)
    return lines


def main() -> None:
    """Run first-stage diagnostics for the region-year internet-use instruments.

    Reads the ESS extract and the Eurostat isoc_r_broad_h region-year file
    relative to the project root (the parent of this script's directory). Fits
    OLS first stages with region-clustered SEs and writes a text summary to
    ``outputs/first_stage_region_year.txt``.
    """
    root = Path(__file__).resolve().parents[1]
    df, fetched_years = _load_first_stage_data(root)
    out_path = root / "outputs" / "first_stage_region_year.txt"

    # Note: For speed, we use OLS with dummies. This is a diagnostic, not final inference.
    dependents = [
        ("netusoft", "netusoft (ordinal numeric)"),
        ("log_netustm", "log(1+netustm)"),
    ]
    fixed_effects = [
        # SpecA: region FE + country×year FE (identification mainly from within-region changes)
        ("SpecA regionFE + country×yearFE", " + C(region) + C(cntry):C(survey_year)"),
        # SpecB: country×year FE only (identification from within-country-year regional differences)
        ("SpecB country×yearFE only", " + C(cntry):C(survey_year)"),
    ]
    rhs = " + ".join((*_INSTRUMENTS, *_CONTROLS))

    lines = [
        "First-stage diagnostics (region-year instruments from Eurostat isoc_r_broad_h)",
        f"Years included: {fetched_years}",
        f"Rows after merge/filter: {len(df):,}",
        "",
    ]
    for dep, label in dependents:
        lines.append(f"== {label} ==")
        for spec, fe_terms in fixed_effects:
            lines.extend(_first_stage_lines(df, dep, spec, f"{dep} ~ {rhs}{fe_terms}"))
            lines.append("")

    out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    print(f"Wrote: {out_path}")


# Run the diagnostics only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
